import math


class AdaptiveClientHyperparams:
    """Per-client hyperparameters with an adaptively decayed temperature.

    The learning rate is fixed at construction. The temperature starts at
    ``initial_temp``, can be re-initialized from a heterogeneity score via
    ``update_temp_from_heterogeneity``, and is multiplied by
    ``temp_adjust_factor`` whenever ``record_performance`` observes
    ``patience`` consecutive non-improving rounds.

    Args:
        lr: Learning rate returned by ``get_lr``; never modified by this class.
        initial_temp: Starting temperature. Note: it is NOT clamped to
            ``[min_temp, max_temp]``.
        min_temp: Floor for the temperature (decay and heterogeneity mapping).
        max_temp: Ceiling and scale of the heterogeneity mapping.
        patience: Consecutive non-improving rounds required before a decay.
        temp_adjust_factor: Multiplier applied on each decay. Presumably
            meant to be in (0, 1); not validated here.
    """

    def __init__(
        self,
        lr: float = 0.01,
        initial_temp: float = 0.5,
        min_temp: float = 0.05,
        max_temp: float = 1.0,
        patience: int = 2,
        temp_adjust_factor: float = 0.8,
    ):

        self.lr = lr

        self.temp = initial_temp
        self.min_temp = min_temp
        self.max_temp = max_temp
        self.temp_adjust_factor = temp_adjust_factor

        self.patience = patience
        self.performance_history = []  # every accuracy passed to record_performance
        self.stagnation_count = 0  # consecutive non-improving rounds seen so far
        self.het_score = 0.0  # Heterogeneity score for the client (last value set)

    def update_temp_from_heterogeneity(
        self, het_score: float, sharpness: float = 2.0
    ) -> None:
        """Initialize temperature based on client's heterogeneity score.

        Sets ``temp = max_temp * exp(-sharpness * het_score)`` clamped to
        ``[min_temp, max_temp]``, so higher scores give lower temperatures.

        Args:
            het_score: The client's heterogeneity score; stored on
                ``self.het_score``.
            sharpness: Decay rate of the exponential mapping. The default 2.0
                reproduces the original hard-coded behavior.

        Raises:
            ValueError: If ``het_score`` is NaN (it would otherwise silently
                produce a NaN temperature, since ``min``/``max`` don't
                filter NaN). No state is modified in that case.
        """
        if math.isnan(het_score):
            raise ValueError("het_score must not be NaN")
        self.het_score = het_score
        try:
            raw_temp = self.max_temp * math.exp(-sharpness * het_score)
        except OverflowError:
            # Very negative scores overflow exp(); the result would be clamped
            # down to max_temp anyway, so saturate instead of crashing.
            raw_temp = math.inf
        self.temp = min(max(raw_temp, self.min_temp), self.max_temp)

    def record_performance(self, accuracy: float):
        """Record performance and adjust temperature if needed.

        A round counts as stagnant when its accuracy is <= the previous
        round's. After ``patience`` consecutive stagnant rounds the
        temperature is multiplied by ``temp_adjust_factor`` (floored at
        ``min_temp``) and the stagnation counter is reset.

        Returns:
            ``(True, "decrease-temp", new_temp)`` when the temperature was
            decreased, otherwise ``(False, "unchanged", None)``.
        """
        self.performance_history.append(accuracy)

        # Stagnation is only tracked from the third sample on, so the
        # round-2-vs-round-1 delta is never counted (a warm-up period).
        if len(self.performance_history) >= 3:

            if self.performance_history[-1] <= self.performance_history[-2]:
                self.stagnation_count += 1
            else:
                self.stagnation_count = 0

            if self.stagnation_count >= self.patience:
                # Only decay while meaningfully (>10%) above the floor;
                # otherwise the decay would be swallowed by the clamp below.
                if self.temp > self.min_temp * 1.1:
                    self.temp *= self.temp_adjust_factor
                    self.temp = max(self.temp, self.min_temp)
                    self.stagnation_count = 0
                    return True, "decrease-temp", self.temp

        return False, "unchanged", None

    def get_lr(self) -> float:
        """Return the (fixed) learning rate."""
        return self.lr

    def get_temp(self) -> float:
        """Return the current temperature."""
        return self.temp
